{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "35da8276",
   "metadata": {},
   "source": [
    "# Appendix Code\n",
    "\n",
    "We begin with some imports."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46328226",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import quantecon_book_networks\n",
    "from mpl_toolkits.mplot3d.axes3d import Axes3D, proj3d\n",
    "from matplotlib import cm\n",
    "\n",
    "# Set to True to also save each figure as a PDF under figures/\n",
    "export_figures = False\n",
    "quantecon_book_networks.config(\"matplotlib\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "483e031e",
   "metadata": {},
   "source": [
    "## Functions\n",
    "\n",
    "### One-to-one and onto functions on $(0,1)$\n",
    "\n",
    "We start by defining the domain and our one-to-one and onto function examples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b617e66",
   "metadata": {},
   "outputs": [],
   "source": [
    "x = np.linspace(0, 1, 100)\n",
    "titles = 'one-to-one', 'onto'\n",
    "labels = '$f(x)=1/2 + x/2$', '$f(x)=4x(1-x)$'\n",
    "funcs = lambda x: 1/2 + x/2, lambda x: 4 * x * (1-x)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "40b1bb2d",
   "metadata": {},
   "source": [
    "The figure can now be produced as follows."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12ba8e7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(1, 2, figsize=(10, 4))\n",
    "for f, ax, lb, ti in zip(funcs, axes, labels, titles):\n",
    "    ax.set_xlim(0, 1)\n",
    "    ax.set_ylim(0, 1.01)\n",
    "    ax.plot(x, f(x), label=lb)\n",
    "    ax.set_title(ti, fontsize=12)\n",
    "    ax.legend(loc='lower center', fontsize=12, frameon=False)\n",
    "    ax.set_xticks((0, 1))\n",
    "    ax.set_yticks((0, 1))\n",
    "\n",
    "if export_figures:\n",
    "    plt.savefig(\"figures/func_types_1.pdf\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fe7452ac",
   "metadata": {},
   "source": [
    "### Some functions are bijections\n",
    "\n",
    "This figure can be produced in a similar manner to Figure 6.1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5907691a",
   "metadata": {},
   "outputs": [],
   "source": [
    "x = np.linspace(0, 1, 100)\n",
    "titles = 'constant', 'bijection'\n",
    "labels = '$f(x)=1/2$', '$f(x)=1-x$'\n",
    "funcs = lambda x: 1/2 + 0 * x, lambda x: 1-x\n",
    "\n",
    "fig, axes = plt.subplots(1, 2, figsize=(10, 4))\n",
    "for f, ax, lb, ti in zip(funcs, axes, labels, titles):\n",
    "    ax.set_xlim(0, 1)\n",
    "    ax.set_ylim(0, 1.01)\n",
    "    ax.plot(x, f(x), label=lb)\n",
    "    ax.set_title(ti, fontsize=12)\n",
    "    ax.legend(loc='lower center', fontsize=12, frameon=False)\n",
    "    ax.set_xticks((0, 1))\n",
    "    ax.set_yticks((0, 1))\n",
    "    \n",
    "if export_figures:\n",
    "    plt.savefig(\"figures/func_types_2.pdf\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bbfde515",
   "metadata": {},
   "source": [
    "## Fixed Points\n",
    "\n",
    "### Graph and fixed points of $G \\colon x \\mapsto 2.125/(1 + x^{-4})$\n",
    "\n",
    "We begin by defining the domain and the function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3110e5cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "xmin, xmax = 0.0000001, 2\n",
    "xgrid = np.linspace(xmin, xmax, 200)\n",
    "g = lambda x: 2.125 / (1 + x**(-4))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6532632d",
   "metadata": {},
   "source": [
    "Next we define our fixed points."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b9ef820",
   "metadata": {},
   "outputs": [],
   "source": [
    "fps_labels = (r'$x_\\ell$', r'$x_m$', r'$x_h$')\n",
    "fps = (0.01, 0.94, 1.98)\n",
    "coords = ((40, 80), (40, -40), (-40, -80))"
   ]
  },
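  {
   "cell_type": "markdown",
   "id": "a1f3c9d2",
   "metadata": {},
   "source": [
    "The values in `fps` are approximate. As a rough check, for $x > 0$ the fixed point condition $G(x) = x$ can be rearranged as follows.\n",
    "\n",
    "$$\n",
    "\\frac{2.125}{1 + x^{-4}} = x\n",
    "\\iff 2.125 \\, x^4 = x (1 + x^4)\n",
    "\\iff x^4 - 2.125 \\, x^3 + 1 = 0\n",
    "$$\n",
    "\n",
    "The value $x = 2$ solves this exactly, since $16 - 17 + 1 = 0$, and the polynomial changes sign between $0.94$ and $0.95$, which locates $x_m$. In addition, $G(x) \\to 0$ as $x \\downarrow 0$, which corresponds to $x_\\ell$. The slightly offset values $0.01$ and $1.98$ are presumably chosen so that the markers sit inside the plotted range; that reading is our assumption and is not stated in the code."
   ]
  },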
  {
   "cell_type": "markdown",
   "id": "87f4c5c2",
   "metadata": {},
   "source": [
    "Finally we can produce the figure."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8f4567d",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=(6.5, 6))\n",
    "\n",
    "ax.set_xlim(xmin, xmax)\n",
    "ax.set_ylim(xmin, xmax)\n",
    "\n",
    "ax.plot(xgrid, g(xgrid), 'b-', lw=2, alpha=0.6, label='$G$')\n",
    "ax.plot(xgrid, xgrid, 'k-', lw=1, alpha=0.7, label=r'$45^\\circ$')\n",
    "\n",
    "ax.legend(fontsize=14)\n",
    "\n",
    "ax.plot(fps, fps, 'ro', ms=8, alpha=0.6)\n",
    "\n",
    "for (fp, lb, coord) in zip(fps, fps_labels, coords):\n",
    "    ax.annotate(lb, \n",
    "             xy=(fp, fp),\n",
    "             xycoords='data',\n",
    "             xytext=coord,\n",
    "             textcoords='offset points',\n",
    "             fontsize=16,\n",
    "             arrowprops=dict(arrowstyle=\"->\"))\n",
    "\n",
    "if export_figures:\n",
    "    plt.savefig(\"figures/three_fixed_points.pdf\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b6b47191",
   "metadata": {},
   "source": [
    "## Complex Numbers\n",
    "\n",
    "### The complex number $(a, b) = r e^{i \\phi}$\n",
    "\n",
    "We start by abbreviating some useful values and functions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee845f66",
   "metadata": {},
   "outputs": [],
   "source": [
    "π = np.pi\n",
    "zeros = np.zeros\n",
    "ones = np.ones\n",
    "fs = 18"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4a40f91e",
   "metadata": {},
   "source": [
    "Next we set our parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa26afd6",
   "metadata": {},
   "outputs": [],
   "source": [
    "r = 2\n",
    "φ = π/3\n",
    "x = r * np.cos(φ)\n",
    "x_range = np.linspace(0, x, 1000)\n",
    "φ_range = np.linspace(0, φ, 1000)"
   ]
  },
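  {
   "cell_type": "markdown",
   "id": "b7e20c44",
   "metadata": {},
   "source": [
    "To connect these parameters with the labels used in the plot, here is the conversion from polar to Cartesian coordinates for $r = 2$ and $\\varphi = \\pi/3$.\n",
    "\n",
    "$$\n",
    "a = r \\cos \\varphi = 2 \\cdot \\frac{1}{2} = 1,\n",
    "\\qquad\n",
    "b = r \\sin \\varphi = 2 \\cdot \\frac{\\sqrt{3}}{2} = \\sqrt{3}\n",
    "$$\n",
    "\n",
    "The vertical segment from $(1, 0)$ to $(1, \\sqrt{3})$ consists of points whose horizontal coordinate equals $a$. Writing a point in polar form as $(\\theta, \\rho)$, a notation used only for this remark, that condition reads $\\rho \\cos \\theta = a$, so $\\rho = a / \\cos \\theta$ for $\\theta \\in [0, \\varphi]$. This is why the plotting code draws `x / np.cos(φ_range)` against `φ_range`."
   ]
  },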
  {
   "cell_type": "markdown",
   "id": "d55d19ec",
   "metadata": {},
   "source": [
    "Finally we produce the plot."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7036969e",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(7, 7))\n",
    "ax = plt.subplot(111, projection='polar')\n",
    "\n",
    "ax.plot((0, φ), (0, r), marker='o', color='b', alpha=0.5)     # Plot r\n",
    "ax.plot(zeros(x_range.shape), x_range, color='b', alpha=0.5)  # Plot x\n",
    "ax.plot(φ_range, x / np.cos(φ_range), color='b', alpha=0.5)   # Plot y\n",
    "ax.plot(φ_range, ones(φ_range.shape) * 0.15, color='green')   # Plot φ\n",
    "\n",
    "ax.margins(0)  # Let the plot start at the origin\n",
    "\n",
    "ax.set_rmax(2)\n",
    "ax.set_rticks((1, 2))          # Fewer radial ticks\n",
    "ax.set_rlabel_position(-88.5)  # Move radial labels away from the plotted line\n",
    "\n",
    "ax.text(φ, r + 0.04, r'$(a, b) = (1, \\sqrt{3})$', fontsize=fs)  # Label z\n",
    "ax.text(φ + 0.4, 1, '$r = 2$', fontsize=fs)                     # Label r\n",
    "ax.text(0 - 0.4, 0.5, '$1$', fontsize=fs)                       # Label x\n",
    "ax.text(0.5, 1.2, r'$\\sqrt{3}$', fontsize=fs)                   # Label y\n",
    "ax.text(0.3, 0.25, r'$\\varphi = \\pi/3$', fontsize=fs)           # Label φ\n",
    "\n",
    "xT=plt.xticks()[0]\n",
    "xL=['0',\n",
    "    r'$\\frac{\\pi}{4}$',\n",
    "    r'$\\frac{\\pi}{2}$',\n",
    "    r'$\\frac{3\\pi}{4}$',\n",
    "    r'$\\pi$',\n",
    "    r'$\\frac{5\\pi}{4}$',\n",
    "    r'$\\frac{3\\pi}{2}$',\n",
    "    r'$\\frac{7\\pi}{4}$']\n",
    "\n",
    "plt.xticks(xT, xL, fontsize=fs+2)\n",
    "ax.grid(True)\n",
    "\n",
    "if export_figures:\n",
    "    plt.savefig(\"figures/complex_number.pdf\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "418ac96b",
   "metadata": {},
   "source": [
    "## Convergence\n",
    "\n",
    "### Convergence of a sequence to the origin in $\\mathbb{R}^3$\n",
    "\n",
    "We define our transformation matrix, initial point, and number of iterations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ebbf6747",
   "metadata": {},
   "outputs": [],
   "source": [
    "θ = 0.1\n",
    "A = ((np.cos(θ), - np.sin(θ), 0.0001),\n",
    "     (np.sin(θ),   np.cos(θ), 0.001),\n",
    "     (np.sin(θ),   np.cos(θ), 1))\n",
    "\n",
    "A = 0.98 * np.array(A)\n",
    "p = np.array((1, 1, 1))\n",
    "n = 200"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "815fed7d",
   "metadata": {},
   "source": [
    "Now we can produce the plot by repeatedly transforming our point with the transformation matrix and plotting each resulting point."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14e5a9a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(6, 6))\n",
    "ax = fig.add_subplot(111, projection='3d')\n",
    "ax.view_init(elev=20, azim=-40)\n",
    "\n",
    "ax.set_xlim((-1.5, 1.5))\n",
    "ax.set_ylim((-1.5, 1.5))\n",
    "ax.set_xticks((-1,0,1))\n",
    "ax.set_yticks((-1,0,1))\n",
    "\n",
    "for i in range(n):\n",
    "    x, y, z = p\n",
    "    ax.plot([x], [y], [z], 'o', ms=4, color=cm.jet_r(i / n))\n",
    "    p = A @ p\n",
    "    \n",
    "if export_figures:\n",
    "    plt.savefig(\"figures/euclidean_convergence_1.pdf\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8804299c",
   "metadata": {},
   "source": [
    "## Linear Algebra\n",
    "\n",
    "### The span of vectors $u$, $v$, $w$ in $\\mathbb{R}^3$\n",
    "\n",
    "We begin by importing the `FancyArrowPatch` class and extending it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07b47dce",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib.patches import FancyArrowPatch\n",
    "\n",
    "class Arrow3D(FancyArrowPatch):\n",
    "    def __init__(self, xs, ys, zs, *args, **kwargs):\n",
    "        FancyArrowPatch.__init__(self, (0,0), (0,0), *args, **kwargs)\n",
    "        self._verts3d = xs, ys, zs\n",
    "\n",
    "    def do_3d_projection(self, renderer=None):\n",
    "        xs3d, ys3d, zs3d = self._verts3d\n",
    "        xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M)\n",
    "        self.set_positions((xs[0],ys[0]),(xs[1],ys[1]))\n",
    "\n",
    "        return np.min(zs)\n",
    "\n",
    "    def draw(self, renderer):\n",
    "        xs3d, ys3d, zs3d = self._verts3d\n",
    "        xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M)\n",
    "        self.set_positions((xs[0],ys[0]),(xs[1],ys[1]))\n",
    "        FancyArrowPatch.draw(self, renderer)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2a087a52",
   "metadata": {},
   "source": [
    "Next we generate our vectors $u$, $v$, $w$, ensuring linear dependence."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ced11ab1",
   "metadata": {},
   "outputs": [],
   "source": [
    "α, β = 0.2, 0.1\n",
    "def f(x, y):\n",
    "    return α * x + β * y\n",
    "\n",
    "# Vector locations, by coordinate\n",
    "x_coords = np.array((3, 3, -3.5))\n",
    "y_coords = np.array((4, -4, 3.0))\n",
    "z_coords = f(x_coords, y_coords)\n",
    "\n",
    "vecs = [np.array((x, y, z)) for x, y, z in zip(x_coords, y_coords, z_coords)]"
   ]
  },
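  {
   "cell_type": "markdown",
   "id": "c58d1e9a",
   "metadata": {},
   "source": [
    "As a quick check of the linear dependence, note that each vector has the form $(x, y, \\alpha x + \\beta y)$ with $\\alpha = 0.2$ and $\\beta = 0.1$, so all three lie in the same plane through the origin. Up to floating point rounding, the vectors are\n",
    "\n",
    "$$\n",
    "u = (3, 4, 1), \\qquad v = (3, -4, 0.2), \\qquad w = (-3.5, 3, -0.4)\n",
    "$$\n",
    "\n",
    "Solving for the weights on the first two coordinates gives\n",
    "\n",
    "$$\n",
    "w = -\\frac{5}{24} u - \\frac{23}{24} v\n",
    "$$\n",
    "\n",
    "and the third coordinate agrees, since $-\\frac{5}{24} \\cdot 1 - \\frac{23}{24} \\cdot 0.2 = -0.4$. Hence the span of $u$, $v$, $w$ is a plane rather than all of $\\mathbb{R}^3$."
   ]
  },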
  {
   "cell_type": "markdown",
   "id": "165912ee",
   "metadata": {},
   "source": [
    "Next we define the spanning plane."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97870984",
   "metadata": {},
   "outputs": [],
   "source": [
    "x_min, x_max = -5, 5\n",
    "y_min, y_max = -5, 5\n",
    "\n",
    "grid_size = 20\n",
    "xr2 = np.linspace(x_min, x_max, grid_size)\n",
    "yr2 = np.linspace(y_min, y_max, grid_size)\n",
    "x2, y2 = np.meshgrid(xr2, yr2)\n",
    "z2 = f(x2, y2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "be62269c",
   "metadata": {},
   "source": [
    "Finally we generate the plot."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49b4eff0",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(12, 7))\n",
    "ax = plt.axes(projection ='3d')\n",
    "ax.view_init(elev=10., azim=-80)\n",
    "\n",
    "ax.set(xlim=(x_min, x_max), \n",
    "       ylim=(x_min, x_max), \n",
    "       zlim=(x_min, x_max),\n",
    "       xticks=(0,), yticks=(0,), zticks=(0,))\n",
    "\n",
    "# Draw the axes\n",
    "gs = 3\n",
    "z = np.linspace(x_min, x_max, gs)\n",
    "x = np.zeros(gs)\n",
    "y = np.zeros(gs)\n",
    "ax.plot(x, y, z, 'k-', lw=2, alpha=0.5)\n",
    "ax.plot(z, x, y, 'k-', lw=2, alpha=0.5)\n",
    "ax.plot(y, z, x, 'k-', lw=2, alpha=0.5)\n",
    "\n",
    "# Draw the vectors\n",
    "for v in vecs:\n",
    "    a = Arrow3D([0, v[0]], \n",
    "                [0, v[1]], \n",
    "                [0, v[2]], \n",
    "                mutation_scale=20, \n",
    "                lw=1, \n",
    "                arrowstyle=\"-|>\", \n",
    "                color=\"b\")\n",
    "    ax.add_artist(a)\n",
    "\n",
    "\n",
    "for v, label in zip(vecs, ('u', 'v', 'w')):\n",
    "    v = v * 1.1\n",
    "    ax.text(v[0], v[1], v[2], \n",
    "            f'${label}$', \n",
    "            fontsize=14)\n",
    "\n",
    "# Draw the plane\n",
    "grid_size = 20\n",
    "xr2 = np.linspace(x_min, x_max, grid_size)\n",
    "yr2 = np.linspace(y_min, y_max, grid_size)\n",
    "x2, y2 = np.meshgrid(xr2, yr2)\n",
    "z2 = f(x2, y2)\n",
    "\n",
    "ax.plot_surface(x2, y2, z2, rstride=1, cstride=1, cmap=cm.jet,\n",
    "                linewidth=0, antialiased=True, alpha=0.2)\n",
    "\n",
    "\n",
    "if export_figures:\n",
    "    plt.savefig(\"figures/span1.pdf\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b87b7088",
   "metadata": {},
   "source": [
    "## Linear Maps Are Matrices\n",
    "\n",
    "### Equivalence of the onto and one-to-one properties (for linear maps)\n",
    "\n",
    "This plot is produced similarly to Figures 6.1 and 6.2."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a243b6e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(1, 2, figsize=(10, 4))\n",
    "\n",
    "x = np.linspace(-2, 2, 10)\n",
    "\n",
    "titles = 'non-bijection', 'bijection'\n",
    "labels = '$f(x)=0 x$', '$f(x)=0.5 x$'\n",
    "funcs = lambda x: 0*x, lambda x: 0.5 * x\n",
    "\n",
    "for ax, f, lb, ti in zip(axes, funcs, labels, titles):\n",
    "\n",
    "    # Set the axes through the origin\n",
    "    for spine in ['left', 'bottom']:\n",
    "        ax.spines[spine].set_position('zero')\n",
    "    for spine in ['right', 'top']:\n",
    "        ax.spines[spine].set_color('none')\n",
    "    ax.set_yticks((-1,  1))\n",
    "    ax.set_ylim((-1, 1))\n",
    "    ax.set_xlim((-1, 1))\n",
    "    ax.set_xticks((-1, 1))\n",
    "    y = f(x)\n",
    "    ax.plot(x, y, '-', linewidth=4, label=lb, alpha=0.6)\n",
    "    ax.text(-0.8, 0.5, ti, fontsize=14)\n",
    "    ax.legend(loc='lower right', fontsize=12)\n",
    "    \n",
    "if export_figures:\n",
    "    plt.savefig(\"figures/func_types_3.pdf\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e3aa06c3",
   "metadata": {},
   "source": [
    "## Convexity and Polyhedra\n",
    "\n",
    "### A polyhedron $P$ represented as intersecting halfspaces\n",
    "\n",
    "Each halfspace is defined by an inequality of the form\n",
    "\n",
    "$$ a x + b y \\leq c $$\n",
    "\n",
    "To plot the halfspace, we first plot its boundary line (assuming $b \\neq 0$)\n",
    "\n",
    "$$ y = \\frac{c}{b} - \\frac{a}{b} x $$\n",
    "\n",
    "and then shade the halfspace using `fill_between` on the points $x, y, \\hat y$,\n",
    "where $\\hat y$ is either `y_min` or `y_max`, depending on which side of the line is shaded.\n",
    "In the legend, the coordinates $x$ and $y$ are written $x_1$ and $x_2$."
   ]
  },
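  {
   "cell_type": "markdown",
   "id": "d90a7b36",
   "metadata": {},
   "source": [
    "Which of `y_min` and `y_max` to fill toward follows from the sign of $b$. A short derivation, offered as a reading aid rather than as part of the original text:\n",
    "\n",
    "$$\n",
    "a x + b y \\leq c\n",
    "\\iff\n",
    "\\begin{cases}\n",
    "y \\leq \\dfrac{c}{b} - \\dfrac{a}{b} x & \\text{if } b > 0 \\\\\n",
    "y \\geq \\dfrac{c}{b} - \\dfrac{a}{b} x & \\text{if } b < 0\n",
    "\\end{cases}\n",
    "$$\n",
    "\n",
    "So for $b > 0$ the halfspace lies below the line and $\\hat y$ is `y_min`, while for $b < 0$ it lies above and $\\hat y$ is `y_max`. In the code, the first line has $b_1 > 0$ but is shaded toward `y_max`, so the region drawn there is $a_1 x + b_1 y \\geq c_1$, which is the same as $-a_1 x - b_1 y \\leq -c_1$."
   ]
  },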
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae4e8b4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots()\n",
    "plt.axis('off')\n",
    "\n",
    "x = np.linspace(-10, 14, 200)\n",
    "y_min, y_max = -2, 3\n",
    "\n",
    "a1, b1, c1 = 1.0, 8.0, -5.0\n",
    "y = c1 / b1 - (a1 / b1) * x\n",
    "ax.plot(x, y, label='$a_1 x_1 + b_1 x_2 = c_1$')\n",
    "ax.fill_between(x, y, y_max, alpha=0.2)\n",
    "\n",
    "\n",
    "a2, b2, c2 = 0.5, 0.75, 1.5\n",
    "y = c2 / b2 - (a2 / b2) * x\n",
    "ax.plot(x, y, label='$a_2 x_1 + b_2 x_2 = c_2$')\n",
    "ax.fill_between(x, y, y_min, alpha=0.2)\n",
    "\n",
    "\n",
    "a3, b3, c3 = -1.0, 1.0, 1.5\n",
    "y = c3 / b3 - (a3 / b3) * x\n",
    "ax.plot(x, y, label='$a_3 x_1 + b_3 x_2 = c_3$')\n",
    "ax.fill_between(x, y, y_min, alpha=0.2)\n",
    "\n",
    "ax.plot((0.23,), (1.82,), 'ko')\n",
    "ax.plot((-1.95,), (-0.4,), 'ko')\n",
    "ax.plot((4.8,), (-1.2,), 'ko')\n",
    "\n",
    "\n",
    "ax.annotate('$P$', xy=(0, 0), fontsize=12)\n",
    "\n",
    "ax.set_ylim(y_min, y_max)\n",
    "if export_figures:\n",
    "    plt.savefig(\"figures/polyhedron1.pdf\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45b13eed",
   "metadata": {},
   "source": [
    "## Saddle Points and Duality"
   ]
  },
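  {
   "cell_type": "markdown",
   "id": "e4c6f015",
   "metadata": {},
   "source": [
    "The code plots the function $L(x, \\theta) = x^2 - \\theta^2$ on $[-1, 1]^2$, with the code variable `y` playing the role of $\\theta$. As a brief reminder of why $(x^*, \\theta^*) = (0, 0)$ is labeled as the saddle point in these panels:\n",
    "\n",
    "$$\n",
    "L(x^*, \\theta) = -\\theta^2\n",
    "\\; \\leq \\; 0 = L(x^*, \\theta^*)\n",
    "\\; \\leq \\; x^2 = L(x, \\theta^*)\n",
    "$$\n",
    "\n",
    "That is, $\\theta \\mapsto L(x^*, \\theta)$ is maximized at $\\theta^*$ (bottom left panel) and $x \\mapsto L(x, \\theta^*)$ is minimized at $x^*$ (bottom right panel)."
   ]
  },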
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53bdce43",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(8.5, 6))\n",
    "\n",
    "# Top left plot\n",
    "\n",
    "ax = fig.add_subplot(221, projection='3d')\n",
    "\n",
    "plot_args = {'rstride': 1, 'cstride': 1, 'cmap': \"viridis\",\n",
    "             'linewidth': 0.4, 'antialiased': True, \"alpha\": 0.75,\n",
    "             'vmin': -1, 'vmax': 1}\n",
    "\n",
    "x, y = np.mgrid[-1:1:31j, -1:1:31j]\n",
    "z = x**2 - y**2\n",
    "\n",
    "ax.plot_surface(x, y, z, **plot_args)\n",
    "\n",
    "ax.view_init(azim=245, elev=20)\n",
    "\n",
    "ax.set_xlim(-1, 1)\n",
    "ax.set_ylim(-1, 1)\n",
    "ax.set_zlim(-1, 1)\n",
    "\n",
    "ax.set_xticks([0])\n",
    "ax.set_xticklabels([r\"$x^*$\"], fontsize=16)\n",
    "ax.set_yticks([0])\n",
    "ax.set_yticklabels([r\"$\\theta^*$\"], fontsize=16)\n",
    "\n",
    "\n",
    "ax.set_zticks([])\n",
    "ax.set_zticklabels([])\n",
    "\n",
    "# Make the panes transparent (the older w_xaxis-style aliases\n",
    "# are not available in recent Matplotlib releases)\n",
    "ax.xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))\n",
    "ax.yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))\n",
    "ax.zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))\n",
    "\n",
    "ax.set_zlabel(r\"$L(x,\\theta)$\", fontsize=14)\n",
    "\n",
    "# Top right plot\n",
    "\n",
    "ax = fig.add_subplot(222)\n",
    "\n",
    "\n",
    "plot_args = {'cmap': \"viridis\", 'antialiased': True, \"alpha\": 0.6,\n",
    "             'vmin': -1, 'vmax': 1}\n",
    "\n",
    "x, y = np.mgrid[-1:1:31j, -1:1:31j]\n",
    "z = x**2 - y**2\n",
    "\n",
    "ax.contourf(x, y, z, **plot_args)\n",
    "\n",
    "ax.plot([0], [0], 'ko')\n",
    "\n",
    "ax.set_xlim(-1, 1)\n",
    "ax.set_ylim(-1, 1)\n",
    "\n",
    "plt.xticks([0],\n",
    "           [r\"$x^*$\"], fontsize=16)\n",
    "\n",
    "plt.yticks([0],\n",
    "           [r\"$\\theta^*$\"], fontsize=16)\n",
    "\n",
    "ax.hlines(0, -1, 1, color='k', ls='-', lw=1)\n",
    "ax.vlines(0, -1, 1, color='k', ls='-', lw=1)\n",
    "\n",
    "coords=(-35, 30)\n",
    "ax.annotate(r'$L(x, \\theta^*)$', \n",
    "             xy=(-0.5, 0),  \n",
    "             xycoords=\"data\",\n",
    "             xytext=coords,\n",
    "             textcoords=\"offset points\",\n",
    "             fontsize=12,\n",
    "             arrowprops={\"arrowstyle\" : \"->\"})\n",
    "\n",
    "coords=(35, 30)\n",
    "ax.annotate(r'$L(x^*, \\theta)$', \n",
    "             xy=(0, 0.25),  \n",
    "             xycoords=\"data\",\n",
    "             xytext=coords,\n",
    "             textcoords=\"offset points\",\n",
    "             fontsize=12,\n",
    "             arrowprops={\"arrowstyle\" : \"->\"})\n",
    "\n",
    "# Bottom left plot\n",
    "\n",
    "ax = fig.add_subplot(223)\n",
    "\n",
    "x = np.linspace(-1, 1, 100)\n",
    "ax.plot(x, -x**2, label=r'$\\theta \\mapsto L(x^*, \\theta)$')\n",
    "ax.set_ylim((-1, 1))\n",
    "ax.legend(fontsize=14)\n",
    "ax.set_xticks([])\n",
    "ax.set_yticks([])\n",
    "\n",
    "# Bottom right plot\n",
    "\n",
    "ax = fig.add_subplot(224)\n",
    "\n",
    "x = np.linspace(-1, 1, 100)\n",
    "ax.plot(x, x**2, label=r'$x \\mapsto L(x, \\theta^*)$')\n",
    "ax.set_ylim((-1, 1))\n",
    "ax.legend(fontsize=14, loc='lower right')\n",
    "ax.set_xticks([])\n",
    "ax.set_yticks([])\n",
    "\n",
    "if export_figures:\n",
    "    plt.savefig(\"figures/saddle_1.pdf\")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
